/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 *    RandomTree.java
 *    Copyright (C) 2001 University of Waikato, Hamilton, New Zealand
 *
 */

package weka.classifiers.trees;

import weka.classifiers.Classifier;
import weka.core.Attribute;
import weka.core.Capabilities;
import weka.core.ContingencyTables;
import weka.core.Drawable;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.Randomizable;
import weka.core.RevisionUtils;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;
import weka.core.Capabilities.Capability;

import java.util.Enumeration;
import java.util.Random;
import java.util.Vector;

/**
 * <!-- globalinfo-start --> Class for constructing a tree that considers K
 * randomly chosen attributes at each node. Performs no pruning. Also has an
 * option to allow estimation of class probabilities based on a hold-out set
 * (backfitting).
 * <p/>
 * <!-- globalinfo-end -->
 * 
 * <!-- options-start --> Valid options are:
 * <p/>
 * 
 * <pre>
 * -K &lt;number of attributes&gt;
 *  Number of attributes to randomly investigate
 *  (&lt;1 = int(log_2(#attributes)+1)).
 * </pre>
 * 
 * <pre>
 * -M &lt;minimum number of instances&gt;
 *  Set minimum number of instances per leaf.
 * </pre>
 * 
 * <pre>
 * -S &lt;num&gt;
 *  Seed for random number generator.
 *  (default 1)
 * </pre>
 * 
 * <pre>
 * -depth &lt;num&gt;
 *  The maximum depth of the tree, 0 for unlimited.
 *  (default 0)
 * </pre>
 * 
 * <pre>
 * -N &lt;num&gt;
 *  Number of folds for backfitting (default 0, no backfitting).
 * </pre>
 * 
 * <pre>
 * -U
 *  Allow unclassified instances.
 * </pre>
 * 
 * <pre>
 * -D
 *  If set, classifier is run in debug mode and
 *  may output additional info to the console
 * </pre>
 * 
 * <!-- options-end -->
 * 
 * @author Eibe Frank (eibe@cs.waikato.ac.nz)
 * @author Richard Kirkby (rkirkby@cs.waikato.ac.nz)
 * @version $Revision: 5535 $
 */
public class RandomTree extends Classifier implements OptionHandler,
		WeightedInstancesHandler, Randomizable, Drawable {

	/** for serialization */
	static final long serialVersionUID = 8934314652175299374L;

	/** The subtrees appended to this tree. */
	protected RandomTree[] m_Successors;

	/** The attribute to split on. */
	protected int m_Attribute = -1;

	/** The split point. */
	protected double m_SplitPoint = Double.NaN;

	/** The header information. */
	protected Instances m_Info = null;

	/** The proportions of training instances going down each branch. */
	protected double[] m_Prop = null;

	/** Class probabilities from the training data. */
	protected double[] m_ClassDistribution = null;

	/** Minimum number of instances for leaf. */
	protected double m_MinNum = 1.0;

	/** The number of attributes considered for a split. */
	protected int m_KValue = 0;

	/** The random seed to use. */
	protected int m_randomSeed = 1;

	/** The maximum depth of the tree (0 = unlimited) */
	protected int m_MaxDepth = 0;

	/** Determines how much data is used for backfitting */
	protected int m_NumFolds = 0;

	/** Whether unclassified instances are allowed */
	protected boolean m_AllowUnclassifiedInstances = false;

	/** a ZeroR model in case no model can be built from the data */
	protected Classifier m_ZeroR;

	/**
	 * Returns a string describing classifier
	 * 
	 * @return a description suitable for displaying in the
	 *         explorer/experimenter gui
	 */
	public String globalInfo() {

		// Assemble the description piece by piece; the fragments are joined
		// verbatim, so the leading blanks of each fragment are significant.
		StringBuffer description = new StringBuffer();
		description.append("Class for constructing a tree that considers K randomly ");
		description.append(" chosen attributes at each node. Performs no pruning. Also has");
		description.append(" an option to allow estimation of class probabilities based on");
		description.append(" a hold-out set (backfitting).");
		return description.toString();
	}

	/**
	 * Returns the tip text for this property
	 * 
	 * @return tip text for this property suitable for displaying in the
	 *         explorer/experimenter gui
	 */
	public String minNumTipText() {
		// GUI tool tip for the minNum property
		final String tip = "The minimum total weight of the instances in a leaf.";
		return tip;
	}

	/**
	 * Get the value of MinNum.
	 * 
	 * @return Value of MinNum.
	 */
	public double getMinNum() {
		// Plain accessor for the minimum leaf weight
		return this.m_MinNum;
	}

	/**
	 * Set the value of MinNum.
	 * 
	 * @param newMinNum
	 *            Value to assign to MinNum.
	 */
	public void setMinNum(double newMinNum) {
		// Stored as given; no range check is performed here
		this.m_MinNum = newMinNum;
	}

	/**
	 * Returns the tip text for this property
	 * 
	 * @return tip text for this property suitable for displaying in the
	 *         explorer/experimenter gui
	 */
	public String KValueTipText() {
		// GUI tool tip for the KValue property
		final String tip = "Sets the number of randomly chosen attributes. If 0, log_2(number_of_attributes) + 1 is used.";
		return tip;
	}

	/**
	 * Get the value of K.
	 * 
	 * @return Value of K.
	 */
	public int getKValue() {
		// Plain accessor for the number of attributes tried per split
		return this.m_KValue;
	}

	/**
	 * Set the value of K.
	 * 
	 * @param k
	 *            Value to assign to K.
	 */
	public void setKValue(int k) {
		// Stored as given; buildClassifier() brings it into range later
		this.m_KValue = k;
	}

	/**
	 * Returns the tip text for this property
	 * 
	 * @return tip text for this property suitable for displaying in the
	 *         explorer/experimenter gui
	 */
	public String seedTipText() {
		// GUI tool tip for the seed property
		final String tip = "The random number seed used for selecting attributes.";
		return tip;
	}

	/**
	 * Set the seed for random number generation.
	 * 
	 * @param seed
	 *            the seed
	 */
	public void setSeed(int seed) {
		// Remembered until the next call to buildClassifier()
		this.m_randomSeed = seed;
	}

	/**
	 * Gets the seed for the random number generations
	 * 
	 * @return the seed for the random number generation
	 */
	public int getSeed() {
		// Plain accessor for the random seed
		return this.m_randomSeed;
	}

	/**
	 * Returns the tip text for this property
	 * 
	 * @return tip text for this property suitable for displaying in the
	 *         explorer/experimenter gui
	 */
	public String maxDepthTipText() {
		// GUI tool tip for the maxDepth property
		final String tip = "The maximum depth of the tree, 0 for unlimited.";
		return tip;
	}

	/**
	 * Get the maximum depth of the tree, 0 for unlimited.
	 * 
	 * @return the maximum depth.
	 */
	public int getMaxDepth() {
		// Plain accessor; 0 means the depth is not limited
		return this.m_MaxDepth;
	}

	/**
	 * Returns the tip text for this property
	 * 
	 * @return tip text for this property suitable for displaying in the
	 *         explorer/experimenter gui
	 */
	public String numFoldsTipText() {
		// GUI tool tip for the numFolds property, built from two fragments
		StringBuffer tip = new StringBuffer();
		tip.append("Determines the amount of data used for backfitting. One fold is used for ");
		tip.append("backfitting, the rest for growing the tree. (Default: 0, no backfitting)");
		return tip.toString();
	}

	/**
	 * Get the value of NumFolds.
	 * 
	 * @return Value of NumFolds.
	 */
	public int getNumFolds() {
		// Plain accessor for the number of backfitting folds
		return this.m_NumFolds;
	}

	/**
	 * Set the value of NumFolds.
	 * 
	 * @param newNumFolds
	 *            Value to assign to NumFolds.
	 */
	public void setNumFolds(int newNumFolds) {
		// Stored as given; values <= 0 disable backfitting in buildClassifier()
		this.m_NumFolds = newNumFolds;
	}

	/**
	 * Returns the tip text for this property
	 * 
	 * @return tip text for this property suitable for displaying in the
	 *         explorer/experimenter gui
	 */
	public String allowUnclassifiedInstancesTipText() {
		// GUI tool tip for the allowUnclassifiedInstances property
		final String tip = "Whether to allow unclassified instances.";
		return tip;
	}

	/**
	 * Get the value of AllowUnclassifiedInstances.
	 * 
	 * @return Value of AllowUnclassifiedInstances.
	 */
	public boolean getAllowUnclassifiedInstances() {
		// Plain accessor for the "allow unclassified instances" flag
		return this.m_AllowUnclassifiedInstances;
	}

	/**
	 * Set the value of AllowUnclassifiedInstances.
	 * 
	 * @param newAllowUnclassifiedInstances
	 *            Value to assign to AllowUnclassifiedInstances.
	 */
	public void setAllowUnclassifiedInstances(
			boolean newAllowUnclassifiedInstances) {
		// Plain mutator for the "allow unclassified instances" flag
		this.m_AllowUnclassifiedInstances = newAllowUnclassifiedInstances;
	}

	/**
	 * Set the maximum depth of the tree, 0 for unlimited.
	 * 
	 * @param value
	 *            the maximum depth.
	 */
	public void setMaxDepth(int value) {
		// Stored as given; 0 means the depth is not limited
		this.m_MaxDepth = value;
	}

	/**
	 * Lists the command-line options for this classifier.
	 * 
	 * @return an enumeration over all possible options
	 */
	public Enumeration listOptions() {

		// Raw Vector/Enumeration are kept on purpose: the inherited
		// signature is raw and the superclass hands back a raw Enumeration.
		Vector newVector = new Vector();

		// buildClassifier() substitutes the automatic value for every K < 1
		// (which includes the default of 0), so the help text has to say
		// "<1" rather than "<0".
		newVector.addElement(new Option(
				"\tNumber of attributes to randomly investigate\n"
						+ "\t(<1 = int(log_2(#attributes)+1)).", "K", 1,
				"-K <number of attributes>"));

		newVector.addElement(new Option(
				"\tSet minimum number of instances per leaf.", "M", 1,
				"-M <minimum number of instances>"));

		newVector.addElement(new Option("\tSeed for random number generator.\n"
				+ "\t(default 1)", "S", 1, "-S <num>"));

		newVector.addElement(new Option(
				"\tThe maximum depth of the tree, 0 for unlimited.\n"
						+ "\t(default 0)", "depth", 1, "-depth <num>"));

		newVector.addElement(new Option("\tNumber of folds for backfitting "
				+ "(default 0, no backfitting).", "N", 1, "-N <num>"));

		newVector.addElement(new Option("\tAllow unclassified instances.", "U",
				0, "-U"));

		// Append whatever the superclass contributes after our own options
		Enumeration enu = super.listOptions();
		while (enu.hasMoreElements()) {
			newVector.addElement(enu.nextElement());
		}

		return newVector.elements();
	}

	/**
	 * Gets options from this classifier.
	 * 
	 * @return the options for the current setup
	 */
	public String[] getOptions() {

		Vector collected = new Vector();

		// These three settings are always reported
		collected.add("-K");
		collected.add(String.valueOf(getKValue()));
		collected.add("-M");
		collected.add(String.valueOf(getMinNum()));
		collected.add("-S");
		collected.add(String.valueOf(getSeed()));

		// The remaining ones only appear when they differ from "off"
		int depthLimit = getMaxDepth();
		if (depthLimit > 0) {
			collected.add("-depth");
			collected.add(String.valueOf(depthLimit));
		}

		int folds = getNumFolds();
		if (folds > 0) {
			collected.add("-N");
			collected.add(String.valueOf(folds));
		}

		if (getAllowUnclassifiedInstances()) {
			collected.add("-U");
		}

		// Finally append the options handled by the superclass
		String[] inherited = super.getOptions();
		for (int pos = 0; pos < inherited.length; pos++) {
			collected.add(inherited[pos]);
		}

		return (String[]) collected.toArray(new String[collected.size()]);
	}

	/**
	 * Parses a given list of options.
	 * <p/>
	 * 
	 * <!-- options-start --> Valid options are:
	 * <p/>
	 * 
	 * <pre>
	 * -K &lt;number of attributes&gt;
	 *  Number of attributes to randomly investigate
	 *  (&lt;1 = int(log_2(#attributes)+1)).
	 * </pre>
	 * 
	 * <pre>
	 * -M &lt;minimum number of instances&gt;
	 *  Set minimum number of instances per leaf.
	 * </pre>
	 * 
	 * <pre>
	 * -S &lt;num&gt;
	 *  Seed for random number generator.
	 *  (default 1)
	 * </pre>
	 * 
	 * <pre>
	 * -depth &lt;num&gt;
	 *  The maximum depth of the tree, 0 for unlimited.
	 *  (default 0)
	 * </pre>
	 * 
	 * <pre>
	 * -N &lt;num&gt;
	 *  Number of folds for backfitting (default 0, no backfitting).
	 * </pre>
	 * 
	 * <pre>
	 * -U
	 *  Allow unclassified instances.
	 * </pre>
	 * 
	 * <pre>
	 * -D
	 *  If set, classifier is run in debug mode and
	 *  may output additional info to the console
	 * </pre>
	 * 
	 * <!-- options-end -->
	 * 
	 * @param options
	 *            the list of options as an array of strings
	 * @throws Exception
	 *             if an option is not supported
	 */
	public void setOptions(String[] options) throws Exception {

		// Each option is fetched and applied before the next one is looked
		// at; an absent option (empty string) resets to its default.
		String kArg = Utils.getOption('K', options);
		m_KValue = (kArg.length() == 0) ? 0 : Integer.parseInt(kArg);

		String minNumArg = Utils.getOption('M', options);
		m_MinNum = (minNumArg.length() == 0) ? 1 : Double.parseDouble(minNumArg);

		String seedArg = Utils.getOption('S', options);
		setSeed((seedArg.length() == 0) ? 1 : Integer.parseInt(seedArg));

		String depthArg = Utils.getOption("depth", options);
		setMaxDepth((depthArg.length() == 0) ? 0 : Integer.parseInt(depthArg));

		String foldsArg = Utils.getOption('N', options);
		m_NumFolds = (foldsArg.length() == 0) ? 0 : Integer.parseInt(foldsArg);

		setAllowUnclassifiedInstances(Utils.getFlag('U', options));

		// Let the superclass take its own options, then check for leftovers
		super.setOptions(options);
		Utils.checkForRemainingOptions(options);
	}

	/**
	 * Returns default capabilities of the classifier.
	 * 
	 * @return the capabilities of this classifier
	 */
	public Capabilities getCapabilities() {

		// Start from the inherited capabilities with everything switched off
		Capabilities caps = super.getCapabilities();
		caps.disableAll();

		// Attribute capabilities first, class capabilities afterwards
		Capability[] supported = {
				Capability.NOMINAL_ATTRIBUTES,
				Capability.NUMERIC_ATTRIBUTES,
				Capability.DATE_ATTRIBUTES,
				Capability.MISSING_VALUES,
				Capability.NOMINAL_CLASS,
				Capability.MISSING_CLASS_VALUES };
		for (int c = 0; c < supported.length; c++) {
			caps.enable(supported[c]);
		}

		return caps;
	}

	/**
	 * Builds classifier.
	 * 
	 * @param data
	 *            the data to train with
	 * @throws Exception
	 *             if something goes wrong or the data doesn't fit
	 */
	public void buildClassifier(Instances data) throws Exception {

		// Bring K into range: at most the number of non-class attributes,
		// and the automatic value whenever it is below 1
		final int nonClassAtts = data.numAttributes() - 1;
		if (m_KValue > nonClassAtts) {
			m_KValue = nonClassAtts;
		}
		if (m_KValue < 1) {
			m_KValue = (int) Utils.log2(data.numAttributes()) + 1;
		}

		// Reject data this classifier cannot handle
		getCapabilities().testWithFail(data);

		// Work on a copy without instances whose class is missing
		data = new Instances(data);
		data.deleteWithMissingClass();

		// Nothing but the class attribute: fall back to ZeroR
		if (data.numAttributes() == 1) {
			System.err
					.println("Cannot build model (only class attribute present in data!), "
							+ "using ZeroR model instead!");
			m_ZeroR = new weka.classifiers.rules.ZeroR();
			m_ZeroR.buildClassifier(data);
			return;
		}
		m_ZeroR = null;

		// Decide which part grows the tree and which part is backfitted
		Random rand = data.getRandomNumberGenerator(m_randomSeed);
		Instances train = data;
		Instances backfit = null;
		if (m_NumFolds > 0) {
			data.randomize(rand);
			data.stratify(m_NumFolds);
			train = data.trainCV(m_NumFolds, 1, rand);
			backfit = data.testCV(m_NumFolds, 1);
		}

		// Collect the indices of all attributes except the class
		int[] attIndicesWindow = new int[data.numAttributes() - 1];
		int slot = 0;
		for (int att = 0; slot < attIndicesWindow.length; att++) {
			if (att != data.classIndex()) {
				attIndicesWindow[slot++] = att;
			}
		}

		// Weighted class counts of the growing set
		double[] classProbs = new double[train.numClasses()];
		final int numTrain = train.numInstances();
		for (int n = 0; n < numTrain; n++) {
			Instance current = train.instance(n);
			classProbs[(int) current.classValue()] += current.weight();
		}

		// Grow the tree starting at depth 0
		Instances header = new Instances(data, 0);
		buildTree(train, classProbs, header, m_MinNum, m_Debug,
				attIndicesWindow, rand, 0, getAllowUnclassifiedInstances());

		// Re-estimate the distributions on the hold-out fold, if there is one
		if (backfit != null) {
			backfitData(backfit);
		}
	}

	/**
	 * Backfits the given data into the tree.
	 */
	public void backfitData(Instances data) throws Exception {

		// Accumulate the weighted class counts of the backfit data
		double[] counts = new double[data.numClasses()];
		final int total = data.numInstances();
		for (int n = 0; n < total; n++) {
			Instance current = data.instance(n);
			int cls = (int) current.classValue();
			counts[cls] += current.weight();
		}

		// Hand over to the recursive variant, starting at this node
		backfitData(data, counts);
	}

	/**
	 * Computes class distribution of an instance using the decision tree.
	 * 
	 * @param instance
	 *            the instance to compute the distribution for
	 * @return the computed class distribution
	 * @throws Exception
	 *             if computation fails
	 */
	public double[] distributionForInstance(Instance instance) throws Exception {

		// A ZeroR fallback model answers everything itself
		if (m_ZeroR != null) {
			return m_ZeroR.distributionForInstance(instance);
		}

		double[] fromSubtree = null;

		if (m_Attribute > -1) {

			// Inner node
			if (instance.isMissing(m_Attribute)) {

				// Missing split value: blend the answers of all successors,
				// weighted by the stored branch proportions
				fromSubtree = new double[m_Info.numClasses()];
				for (int branch = 0; branch < m_Successors.length; branch++) {
					double[] partial = m_Successors[branch]
							.distributionForInstance(instance);
					if (partial == null) {
						continue;
					}
					for (int cls = 0; cls < partial.length; cls++) {
						fromSubtree[cls] += m_Prop[branch] * partial[cls];
					}
				}
			} else {

				// Known split value: follow exactly one branch
				int branch;
				if (m_Info.attribute(m_Attribute).isNominal()) {
					branch = (int) instance.value(m_Attribute);
				} else {
					branch = (instance.value(m_Attribute) < m_SplitPoint) ? 0 : 1;
				}
				fromSubtree = m_Successors[branch]
						.distributionForInstance(instance);
			}
		}

		// An inner node whose subtree produced an answer passes it on
		if ((m_Attribute != -1) && (fromSubtree != null)) {
			return fromSubtree;
		}

		// Leaf, or the chosen successor was empty: use the counts stored here
		if (m_ClassDistribution != null) {
			double[] normalized = (double[]) m_ClassDistribution.clone();
			Utils.normalize(normalized);
			return normalized;
		}

		// Nothing stored at this node either
		if (getAllowUnclassifiedInstances()) {
			return new double[m_Info.numClasses()];
		}
		return null;
	}

	/**
	 * Outputs the decision tree as a graph
	 * 
	 * @return the tree as a graph
	 */
	public String toGraph() {

		// Collect the node and edge lines first; any failure yields null
		StringBuffer nodes = new StringBuffer();
		try {
			toGraph(nodes, 0);
		} catch (Exception e) {
			return null;
		}

		// Wrap them into the surrounding digraph declaration
		StringBuffer graph = new StringBuffer();
		graph.append("digraph Tree {\n");
		graph.append("edge [style=bold]\n");
		graph.append(nodes.toString());
		graph.append("\n}\n");
		return graph.toString();
	}

	/**
	 * Outputs one node for graph.
	 * 
	 * @param text
	 *            the buffer to append the output to
	 * @param num
	 *            unique node id
	 * @return the next node id
	 * @throws Exception
	 *             if generation fails
	 */
	public int toGraph(StringBuffer text, int num) throws Exception {

		// Choose the node label. buildTree() discards the class distribution
		// of inner nodes whose successors are all non-empty, and empty leaves
		// never have one, so it must not be dereferenced unconditionally
		// (doing so used to raise a NullPointerException, which made the
		// no-argument toGraph() return null for such trees).
		String nodeLabel;
		if (m_ClassDistribution != null) {
			int maxIndex = Utils.maxIndex(m_ClassDistribution);
			nodeLabel = m_Info.classAttribute().value(maxIndex);
		} else if (m_Attribute == -1) {
			// Empty leaf: same fallback index as leafString()
			nodeLabel = m_Info.classAttribute().value(0);
		} else {
			// Inner node without stored counts: show the split attribute
			nodeLabel = m_Info.attribute(m_Attribute).name();
		}

		// Nodes are identified by the hash code of their tree object
		String nodeId = "N" + Integer.toHexString(hashCode());

		num++;
		if (m_Attribute == -1) {
			// Leaf: drawn as a box
			text.append(nodeId + " [label=\"" + num + ": " + nodeLabel + "\""
					+ "shape=box]\n");
		} else {
			text.append(nodeId + " [label=\"" + num + ": " + nodeLabel
					+ "\"]\n");
			for (int i = 0; i < m_Successors.length; i++) {
				// Edge to the i-th successor, labelled with the branch test
				text.append(nodeId + "->" + "N"
						+ Integer.toHexString(m_Successors[i].hashCode())
						+ " [label=\"" + m_Info.attribute(m_Attribute).name());
				if (m_Info.attribute(m_Attribute).isNumeric()) {
					if (i == 0) {
						text.append(" < "
								+ Utils.doubleToString(m_SplitPoint, 2));
					} else {
						text.append(" >= "
								+ Utils.doubleToString(m_SplitPoint, 2));
					}
				} else {
					text.append(" = " + m_Info.attribute(m_Attribute).value(i));
				}
				text.append("\"]\n");

				// Recurse; the successor returns the last node number it used
				num = m_Successors[i].toGraph(text, num);
			}
		}

		return num;
	}

	/**
	 * Outputs the decision tree.
	 * 
	 * @return a string representation of the classifier
	 */
	public String toString() {

		// only ZeroR model?
		if (m_ZeroR != null) {
			StringBuffer buf = new StringBuffer();
			buf.append(this.getClass().getName().replaceAll(".*\\.", "") + "\n");
			buf.append(this.getClass().getName().replaceAll(".*\\.", "")
					.replaceAll(".", "=")
					+ "\n\n");
			buf.append("Warning: No model could be built, hence ZeroR model is used:\n\n");
			buf.append(m_ZeroR.toString());
			return buf.toString();
		}

		// m_Info is assigned at the start of buildTree(), so it tells
		// reliably whether a tree has been built. m_Successors is not a
		// usable marker: it stays null when the built tree is a single leaf,
		// which used to be misreported as "no model has been built yet".
		if (m_Info == null) {
			return "RandomTree: no model has been built yet.";
		}

		return "\nRandomTree\n==========\n"
				+ toString(0)
				+ "\n"
				+ "\nSize of the tree : "
				+ numNodes()
				+ (getMaxDepth() > 0 ? ("\nMax depth of tree: " + getMaxDepth())
						: (""));
	}

	/**
	 * Outputs a leaf.
	 * 
	 * @return the leaf as string
	 * @throws Exception
	 *             if generation fails
	 */
	protected String leafString() throws Exception {

		// Defaults describe an empty leaf: class index 0 with zero weight
		int majority = 0;
		double total = 0;
		double majorityWeight = 0;
		if (m_ClassDistribution != null) {
			total = Utils.sum(m_ClassDistribution);
			majority = Utils.maxIndex(m_ClassDistribution);
			majorityWeight = m_ClassDistribution[majority];
		}

		// Format: " : <class> (<total weight>/<weight of other classes>)"
		StringBuffer out = new StringBuffer(" : ");
		out.append(m_Info.classAttribute().value(majority));
		out.append(" (");
		out.append(Utils.doubleToString(total, 2));
		out.append("/");
		out.append(Utils.doubleToString(total - majorityWeight, 2));
		out.append(")");
		return out.toString();
	}

	/**
	 * Recursively outputs the tree.
	 * 
	 * @param level
	 *            the current level of the tree
	 * @return the generated subtree
	 */
	protected String toString(int level) {

		try {
			// A leaf prints only its summary
			if (m_Attribute == -1) {
				return leafString();
			}

			// Indentation prefix shared by every line emitted at this level
			StringBuffer prefix = new StringBuffer();
			for (int d = 0; d < level; d++) {
				prefix.append("|   ");
			}
			String indent = prefix.toString();

			Attribute splitAtt = m_Info.attribute(m_Attribute);
			StringBuffer out = new StringBuffer();

			if (splitAtt.isNominal()) {

				// One line per attribute value, each followed by its subtree
				for (int branch = 0; branch < m_Successors.length; branch++) {
					out.append("\n");
					out.append(indent);
					out.append(splitAtt.name() + " = " + splitAtt.value(branch));
					out.append(m_Successors[branch].toString(level + 1));
				}
			} else {

				// Binary numeric split: "<" branch first, then ">="
				String threshold = Utils.doubleToString(m_SplitPoint, 2);
				String[] tests = { " < ", " >= " };
				for (int branch = 0; branch < tests.length; branch++) {
					out.append("\n");
					out.append(indent);
					out.append(splitAtt.name() + tests[branch] + threshold);
					out.append(m_Successors[branch].toString(level + 1));
				}
			}

			return out.toString();
		} catch (Exception e) {
			e.printStackTrace();
			return "RandomTree: tree can't be printed";
		}
	}

	/**
	 * Recursively backfits data into the tree.
	 * 
	 * @param data
	 *            the data to work with
	 * @param classProbs
	 *            the class distribution
	 * @throws Exception
	 *             if generation fails
	 */
	protected void backfitData(Instances data, double[] classProbs)
			throws Exception {

		// No backfit instance reaches this node: make it an empty leaf
		if (data.numInstances() == 0) {
			m_Attribute = -1;
			m_ClassDistribution = null;
			m_Prop = null;
			return;
		}

		// Replace the stored counts by those observed in the backfit data.
		// Unlike buildTree(), no minimum-weight or purity test is applied.
		m_ClassDistribution = (double[]) classProbs.clone();

		// A leaf needs nothing beyond the new counts
		if (!(m_Attribute > -1)) {
			return;
		}

		// Re-estimate the branch proportions from instances whose split
		// value is known
		m_Prop = new double[m_Successors.length];
		final int total = data.numInstances();
		for (int n = 0; n < total; n++) {
			Instance current = data.instance(n);
			if (current.isMissing(m_Attribute)) {
				continue;
			}
			int branch;
			if (data.attribute(m_Attribute).isNominal()) {
				branch = (int) current.value(m_Attribute);
			} else {
				branch = (current.value(m_Attribute) < m_SplitPoint) ? 0 : 1;
			}
			m_Prop[branch] += current.weight();
		}

		// Only missing split values were seen: this node becomes a leaf
		if (Utils.sum(m_Prop) <= 0) {
			m_Attribute = -1;
			m_Prop = null;
			return;
		}
		Utils.normalize(m_Prop);

		// Push each part of the data down into the matching successor
		Instances[] parts = splitData(data);
		for (int branch = 0; branch < parts.length; branch++) {
			Instances part = parts[branch];
			double[] partCounts = new double[data.numClasses()];
			for (int n = 0; n < part.numInstances(); n++) {
				partCounts[(int) part.instance(n).classValue()] += part
						.instance(n).weight();
			}
			m_Successors[branch].backfitData(part, partCounts);
		}

		// When unclassified instances are allowed, inner nodes never need
		// their own counts
		if (getAllowUnclassifiedInstances()) {
			m_ClassDistribution = null;
			return;
		}

		// Otherwise keep the counts as a fallback if any successor is empty
		for (int branch = 0; branch < parts.length; branch++) {
			if (m_Successors[branch].m_ClassDistribution == null) {
				return;
			}
		}

		// Every successor is non-empty, so the counts can be dropped. The
		// node is kept as an inner node even if only one successor received
		// data.
		m_ClassDistribution = null;
	}

	/**
	 * Recursively generates a tree.
	 * 
	 * @param data
	 *            the data to work with
	 * @param classProbs
	 *            the class distribution
	 * @param header
	 *            the header of the data
	 * @param minNum
	 *            the minimum number of instances per leaf
	 * @param debug
	 *            whether debugging is on
	 * @param attIndicesWindow
	 *            the attribute window to choose attributes from
	 * @param random
	 *            random number generator for choosing random attributes
	 * @param depth
	 *            the current depth
	 * @param allow
	 *            whether unclassified instances are allowed
	 * @throws Exception
	 *             if generation fails
	 */
	protected void buildTree(Instances data, double[] classProbs,
			Instances header, double minNum, boolean debug,
			int[] attIndicesWindow, Random random, int depth, boolean allow)
			throws Exception {

		// Store structure of dataset, set minimum number of instances.
		// Every node keeps its own copy of these settings because each
		// successor is a separate RandomTree object.
		m_Info = header;
		m_Debug = debug;
		m_MinNum = minNum;
		m_AllowUnclassifiedInstances = allow;

		// Make leaf if there are no training instances
		// (an "empty" leaf: it has no class distribution at all)
		if (data.numInstances() == 0) {
			m_Attribute = -1;
			m_ClassDistribution = null;
			m_Prop = null;
			return;
		}

		// Check if node doesn't contain enough instances or is pure
		// or maximum depth reached
		m_ClassDistribution = (double[]) classProbs.clone();

		// Stop when the total weight is below 2 * m_MinNum, when the largest
		// class count equals the total weight, or when a depth limit is set
		// and has been reached
		if (Utils.sum(m_ClassDistribution) < 2 * m_MinNum
				|| Utils.eq(
						m_ClassDistribution[Utils.maxIndex(m_ClassDistribution)],
						Utils.sum(m_ClassDistribution))
				|| ((getMaxDepth() > 0) && (depth >= getMaxDepth()))) {
			// Make leaf
			m_Attribute = -1;
			m_Prop = null;
			return;
		}

		// Compute class distributions and value of splitting
		// criterion for each attribute.
		// All four arrays are indexed by attribute index; entries of
		// attributes that are never examined keep their initial values
		// (0 for vals/splits, empty arrays for dists/props).
		double[] vals = new double[data.numAttributes()];
		double[][][] dists = new double[data.numAttributes()][0][0];
		double[][] props = new double[data.numAttributes()][0];
		double[] splits = new double[data.numAttributes()];

		// Investigate K random attributes.
		// The loop runs while attributes remain in the window and either
		// fewer than K have been examined or none examined so far has shown
		// a positive gain.
		int attIndex = 0;
		int windowSize = attIndicesWindow.length;
		int k = m_KValue;
		boolean gainFound = false;
		while ((windowSize > 0) && (k-- > 0 || !gainFound)) {

			// Draw one attribute from the still-active part of the window
			int chosenIndex = random.nextInt(windowSize);
			attIndex = attIndicesWindow[chosenIndex];

			// shift chosen attIndex out of window
			// (swap it with the last active entry and shrink the window, so
			// it cannot be drawn again at this node; the array itself is
			// shared with the recursive calls below and stays permuted)
			attIndicesWindow[chosenIndex] = attIndicesWindow[windowSize - 1];
			attIndicesWindow[windowSize - 1] = attIndex;
			windowSize--;

			// distribution() fills props[attIndex] and dists[attIndex] and
			// returns the split point; gain() and priorVal() are defined
			// elsewhere in this class
			splits[attIndex] = distribution(props, dists, attIndex, data);
			vals[attIndex] = gain(dists[attIndex], priorVal(dists[attIndex]));

			if (Utils.gr(vals[attIndex], 0))
				gainFound = true;
		}

		// Find best attribute
		// NOTE(review): presumably Utils.maxIndex returns the first maximum,
		// so with no positive gain this may pick an unexamined attribute;
		// the check below turns the node into a leaf in that case.
		m_Attribute = Utils.maxIndex(vals);
		double[][] distribution = dists[m_Attribute];

		// Any useful split found?
		if (Utils.gr(vals[m_Attribute], 0)) {

			// Build subtrees
			m_SplitPoint = splits[m_Attribute];
			m_Prop = props[m_Attribute];
			Instances[] subsets = splitData(data);
			m_Successors = new RandomTree[distribution.length];
			for (int i = 0; i < distribution.length; i++) {
				// Successors are fresh RandomTree objects; K and the depth
				// limit are copied explicitly, the remaining settings travel
				// through the buildTree() arguments
				m_Successors[i] = new RandomTree();
				m_Successors[i].setKValue(m_KValue);
				m_Successors[i].setMaxDepth(getMaxDepth());
				m_Successors[i].buildTree(subsets[i], distribution[i], header,
						m_MinNum, m_Debug, attIndicesWindow, random, depth + 1,
						allow);
			}

			// If all successors are non-empty, we don't need to store the class
			// distribution
			// (it is only needed as a fallback when prediction reaches an
			// empty successor)
			boolean emptySuccessor = false;
			for (int i = 0; i < subsets.length; i++) {
				if (m_Successors[i].m_ClassDistribution == null) {
					emptySuccessor = true;
					break;
				}
			}
			if (!emptySuccessor) {
				m_ClassDistribution = null;
			}
		} else {

			// Make leaf
			m_Attribute = -1;
		}
	}

	/**
	 * Computes size of the tree.
	 * 
	 * @return the number of nodes
	 */
	public int numNodes() {

		// Count this node, then add the sizes of all subtrees of an inner node
		int total = 1;
		if (m_Attribute != -1) {
			for (int child = 0; child < m_Successors.length; child++) {
				total += m_Successors[child].numNodes();
			}
		}
		return total;
	}

	/**
	 * Splits instances into subsets based on the given split.
	 * 
	 * @param data
	 *            the data to work with
	 * @return the subsets of instances
	 * @throws Exception
	 *             if something goes wrong
	 */
	protected Instances[] splitData(Instances data) throws Exception {

		final int numBranches = m_Prop.length;
		final int numInst = data.numInstances();

		// One empty dataset per branch, each with room for all instances
		Instances[] parts = new Instances[numBranches];
		for (int branch = 0; branch < numBranches; branch++) {
			parts[branch] = new Instances(data, numInst);
		}

		// Route every instance to its branch or branches
		for (int n = 0; n < numInst; n++) {
			Instance current = data.instance(n);

			if (current.isMissing(m_Attribute)) {

				// Missing split value: a down-weighted copy goes into every
				// branch that has a positive proportion
				for (int branch = 0; branch < numBranches; branch++) {
					if (m_Prop[branch] > 0) {
						Instance fraction = (Instance) current.copy();
						fraction.setWeight(m_Prop[branch] * current.weight());
						parts[branch].add(fraction);
					}
				}
			} else if (data.attribute(m_Attribute).isNominal()) {

				// Nominal: the value index is the branch index
				parts[(int) current.value(m_Attribute)].add(current);
			} else if (data.attribute(m_Attribute).isNumeric()) {

				// Numeric: below the split point goes left, the rest right
				int side = (current.value(m_Attribute) < m_SplitPoint) ? 0 : 1;
				parts[side].add(current);
			} else {
				throw new IllegalArgumentException("Unknown attribute type");
			}
		}

		// Release the capacity that was not needed
		for (int branch = 0; branch < numBranches; branch++) {
			parts[branch].compactify();
		}

		return parts;
	}

	/**
	 * Builds the branch-by-class weight table obtained when the data is split
	 * on one attribute, together with the share of weight falling into each
	 * branch.
	 * <p/>
	 * A nominal attribute yields one branch per attribute value. A numeric
	 * attribute yields two branches; the data is sorted on the attribute and
	 * every boundary between two distinct consecutive values is scored with
	 * {@link #gain(double[][], double)}, keeping the midpoint with the largest
	 * gain. Instances whose value for the attribute is missing are afterwards
	 * spread over all branches in proportion to the branch shares.
	 * 
	 * @param props
	 *            output parameter; <code>props[att]</code> is replaced by the
	 *            per-branch shares (normalized, or uniform if the total
	 *            weight is zero)
	 * @param dists
	 *            output parameter; <code>dists[att]</code> is replaced by the
	 *            branch-by-class weight table
	 * @param att
	 *            the attribute index
	 * @param data
	 *            the data to work with; <code>sort(att)</code> is called on it
	 *            when the attribute is not nominal
	 * @return the selected split point, or <code>Double.NaN</code> if the
	 *         attribute is nominal or no split point was selected
	 * @throws Exception
	 *             if something goes wrong
	 */
	protected double distribution(double[][] props, double[][][] dists,
			int att, Instances data) throws Exception {

		final Attribute splitAtt = data.attribute(att);
		final boolean nominal = splitAtt.isNominal();
		final int total = data.numInstances();
		final int classCount = data.numClasses();

		double bestSplit = Double.NaN;
		int firstMissing = -1;
		double[][] table;

		if (nominal) {

			// One row per attribute value; remember where the first
			// missing value sits so the redistribution pass can start there
			table = new double[splitAtt.numValues()][classCount];
			for (int n = 0; n < total; n++) {
				Instance current = data.instance(n);
				if (!current.isMissing(att)) {
					table[(int) current.value(att)][(int) current
							.classValue()] += current.weight();
				} else if (firstMissing < 0) {
					firstMissing = n;
				}
			}
		} else {

			// Row 0: instances below the split, row 1: the remaining ones
			double[][] running = new double[2][classCount];
			table = new double[2][classCount];

			data.sort(att);

			// Start with every known value in row 1. The scan stops at the
			// first missing value (the code relies on the sort placing
			// missing values after all known ones).
			int limit = total;
			for (int n = 0; n < total; n++) {
				Instance current = data.instance(n);
				if (current.isMissing(att)) {
					firstMissing = n;
					limit = n;
					break;
				}
				running[1][(int) current.classValue()] += current.weight();
			}

			// Criterion value of the unsplit data
			final double before = priorVal(running);

			// Fallback table in case no boundary is ever accepted
			System.arraycopy(running[0], 0, table[0], 0, table[0].length);
			System.arraycopy(running[1], 0, table[1], 0, table[1].length);

			// Walk over the known values, scoring each change of value
			double previous = data.instance(0).value(att);
			double topGain = -Double.MAX_VALUE;
			for (int n = 0; n < limit; n++) {
				Instance current = data.instance(n);
				double value = current.value(att);

				if (value > previous) {
					double candidate = gain(running, before);
					if (candidate > topGain) {
						topGain = candidate;
						bestSplit = (value + previous) / 2.0;
						System.arraycopy(running[0], 0, table[0], 0,
								table[0].length);
						System.arraycopy(running[1], 0, table[1], 0,
								table[1].length);
					}
				}
				previous = value;

				// The instance moves from row 1 to row 0
				int cls = (int) current.classValue();
				double weight = current.weight();
				running[0][cls] += weight;
				running[1][cls] -= weight;
			}
		}

		// Branch shares: row sums, normalized (uniform if nothing to
		// normalize)
		double[] shares = new double[table.length];
		props[att] = shares;
		for (int b = 0; b < shares.length; b++) {
			shares[b] = Utils.sum(table[b]);
		}
		if (Utils.eq(Utils.sum(shares), 0)) {
			for (int b = 0; b < shares.length; b++) {
				shares[b] = 1.0 / (double) shares.length;
			}
		} else {
			Utils.normalize(shares);
		}

		// Spread instances with a missing value over all branches
		if (firstMissing > -1) {
			for (int n = firstMissing; n < total; n++) {
				Instance current = data.instance(n);

				// Nominal data is not sorted, so known values may still
				// follow the first missing one; numeric data needs no test
				if (nominal && !current.isMissing(att)) {
					continue;
				}
				int cls = (int) current.classValue();
				double weight = current.weight();
				for (int b = 0; b < table.length; b++) {
					table[b][cls] += shares[b] * weight;
				}
			}
		}

		// Hand back the table and the split point
		dists[att] = table;
		return bestSplit;
	}

	/**
	 * Evaluates the splitting criterion for the data before it is split.
	 * Delegates to <code>ContingencyTables.entropyOverColumns</code>.
	 * 
	 * @param dist
	 *            the distributions
	 * @return the value of the splitting criterion before the split
	 */
	protected double priorVal(double[][] dist) {

		final double entropyBeforeSplit = ContingencyTables
				.entropyOverColumns(dist);
		return entropyBeforeSplit;
	}

	/**
	 * Evaluates the improvement achieved by a split: the criterion value
	 * before the split minus the result of
	 * <code>ContingencyTables.entropyConditionedOnRows</code> for the split
	 * distributions.
	 * 
	 * @param dist
	 *            the distributions
	 * @param priorVal
	 *            the value of the splitting criterion before the split
	 * @return the gain after the split
	 */
	protected double gain(double[][] dist, double priorVal) {

		final double entropyAfterSplit = ContingencyTables
				.entropyConditionedOnRows(dist);
		return priorVal - entropyAfterSplit;
	}

	/**
	 * Returns the revision string, obtained by passing this file's revision
	 * keyword to <code>RevisionUtils.extract</code>.
	 * 
	 * @return the revision
	 */
	public String getRevision() {
		final String revisionKeyword = "$Revision: 5535 $";
		return RevisionUtils.extract(revisionKeyword);
	}

	/**
	 * Command-line entry point: creates a new RandomTree and hands it,
	 * together with the arguments, to <code>runClassifier</code>.
	 * 
	 * @param argv
	 *            the commandline parameters
	 */
	public static void main(String[] argv) {
		RandomTree classifier = new RandomTree();
		runClassifier(classifier, argv);
	}

	/**
	 * Returns the tree as a <code>digraph</code> description: a fixed header,
	 * the node and edge lines produced by
	 * {@link #toGraph(StringBuffer, int, RandomTree)}, and a closing brace.
	 * 
	 * @return the graph describing the tree
	 * @throws Exception
	 *             if this node has no successors (reported as "no model
	 *             built yet") or the graph can't be computed
	 */
	public String graph() throws Exception {

		if (null == m_Successors) {
			throw new Exception("RandomTree: No model built yet.");
		}

		// Collect the node and edge lines first
		StringBuffer nodes = new StringBuffer();
		toGraph(nodes, 0, null);

		// Wrap them in the digraph header and footer
		StringBuffer dot = new StringBuffer("digraph RandomTree {\n");
		dot.append("edge [style=bold]\n");
		dot.append(nodes.toString());
		dot.append("\n}\n");
		return dot.toString();
	}

	/**
	 * Reports which kind of graph {@link #graph()} produces.
	 * 
	 * @return Drawable.TREE
	 */
	public int graphType() {
		final int type = Drawable.TREE;
		return type;
	}

	/**
	 * Outputs one node for graph, followed (for inner nodes) by one labelled
	 * edge per successor and, recursively, the successors themselves.
	 * <p/>
	 * Attribute names, nominal values and the leaf description are passed
	 * through <code>Utils.backQuoteChars</code> before being written, because
	 * they end up inside double-quoted labels; an unescaped quote or
	 * backslash in a name would otherwise produce malformed output.
	 * <p/>
	 * Node identifiers are built from <code>hashCode()</code>, so they are
	 * only as unique as the hash codes of the tree nodes.
	 * 
	 * @param text
	 *            the buffer to append the output to
	 * @param num
	 *            the current node id
	 * @param parent
	 *            the parent of the nodes (not read by this method)
	 * @return the next node id
	 * @throws Exception
	 *             if something goes wrong
	 */
	protected int toGraph(StringBuffer text, int num, RandomTree parent)
			throws Exception {

		num++;
		final String nodeId = "N" + Integer.toHexString(hashCode());

		// Leaf: a single box-shaped node, nothing to recurse into
		if (m_Attribute == -1) {
			text.append(nodeId).append(" [label=\"").append(num)
					.append(Utils.backQuoteChars(leafString()))
					.append("\" shape=box]\n");
			return num;
		}

		// Inner node, labelled with the name of the split attribute
		text.append(nodeId).append(" [label=\"").append(num).append(": ")
				.append(Utils.backQuoteChars(
						m_Info.attribute(m_Attribute).name()))
				.append("\"]\n");

		for (int i = 0; i < m_Successors.length; i++) {

			// Edge from this node to the i-th successor
			text.append(nodeId).append("->N")
					.append(Integer.toHexString(m_Successors[i].hashCode()))
					.append(" [label=\"");
			if (m_Info.attribute(m_Attribute).isNumeric()) {

				// Successor 0 is the "below split point" branch
				text.append(i == 0 ? " < " : " >= ").append(
						Utils.doubleToString(m_SplitPoint, 2));
			} else {
				text.append(" = ").append(
						Utils.backQuoteChars(m_Info.attribute(m_Attribute)
								.value(i)));
			}
			text.append("\"]\n");

			// Subtree; the running node counter is threaded through
			num = m_Successors[i].toGraph(text, num, this);
		}

		return num;
	}
}
